{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "\n",
    "import sys\n",
    "sys.path.append('./python')\n",
    "import caffe\n",
    "\n",
    "sys.path.append('./examples/coco_caption')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<unk>\r\n",
      "a\r\n",
      "on\r\n",
      "of\r\n",
      "the\r\n",
      "in\r\n",
      "with\r\n",
      "and\r\n",
      "is\r\n",
      "man\r\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first 10 entries of the vocabulary file (one token per line).\n",
    "!head -n 10 examples/coco_caption/h5_data/buffer_100/vocabulary.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8801\n"
     ]
    }
   ],
   "source": [
    "# Index 0 is reserved for <EOS>; the file's tokens follow it in file order.\n",
    "# `with` guarantees the file handle is closed once the list is built.\n",
    "with open('examples/coco_caption/h5_data/buffer_100/vocabulary.txt') as vocab_file:\n",
    "    vocabulary = ['<EOS>'] + [line.strip() for line in vocab_file]\n",
    "print(len(vocabulary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1, 8801)\n"
     ]
    }
   ],
   "source": [
    "# Build the deploy net with the weights file for iteration `iter_num`, in TEST phase.\n",
    "iter_num = 110000\n",
    "net = caffe.Net(\n",
    "    './examples/coco_caption/lstm_lm.deploy.prototxt',\n",
    "    './examples/coco_caption/lstm_lm_iter_%d.caffemodel' % iter_num,\n",
    "    caffe.TEST)\n",
    "print(net.blobs['probs'].data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def predict_single_word(net, previous_word, output='probs'):\n",
    "    \"\"\"Run one time step of the net and return its output for that step.\n",
    "\n",
    "    Args:\n",
    "        net: caffe.Net taking `cont_sentence` and `input_sentence` inputs.\n",
    "        previous_word: vocabulary index fed at this step. 0 (<EOS>) is sent\n",
    "            with cont=0 (presumably: start a new sentence); any other index\n",
    "            is sent with cont=1.\n",
    "        output: name of the blob to read; 'probs' by default.\n",
    "\n",
    "    Returns:\n",
    "        A *copy* of net.blobs[output].data[0, 0, :]. pycaffe exposes blob data\n",
    "        as a view of the blob's memory, which the next net.forward() overwrites\n",
    "        in place, so returning the raw view would let later calls silently\n",
    "        change results the caller is still holding.\n",
    "    \"\"\"\n",
    "    cont = 0 if previous_word == 0 else 1\n",
    "    cont_input = np.array([cont])\n",
    "    word_input = np.array([previous_word])\n",
    "    net.forward(cont_sentence=cont_input, input_sentence=word_input)\n",
    "    return net.blobs[output].data[0, 0, :].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Feeding index 0 (<EOS>) sends cont=0, so this is the distribution over the first word.\n",
    "first_word_dist = predict_single_word(net, previous_word=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Vocabulary indices sorted from most to least probable first word.\n",
    "top_preds = np.argsort(-first_word_dist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  2  14   5  13  64  77  30  18  93 142]\n",
      "['a', 'two', 'the', 'an', 'there', 'three', 'some', 'people', 'several', 'this']\n"
     ]
    }
   ],
   "source": [
    "# Top-10 first words: their indices, then the tokens they map to.\n",
    "print(top_preds[:10])\n",
    "print([vocabulary[i] for i in top_preds[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['people', 'men', 'women', 'giraffes', 'zebras', 'young', 'cats', 'elephants', 'horses', 'children']\n"
     ]
    }
   ],
   "source": [
    "# 'two' is a nonzero index, so it is fed with cont=1 and continues from the previous step.\n",
    "second_word_dist = predict_single_word(net, vocabulary.index('two'))\n",
    "print([vocabulary[i] for i in np.argsort(-second_word_dist)[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['standing', 'are', 'in', 'stand', 'walking', 'and', 'eating', 'that', 'walk', 'with']\n"
     ]
    }
   ],
   "source": [
    "# Third step: feed 'giraffes' and rank the distribution this call just returned.\n",
    "third_word_dist = predict_single_word(net, vocabulary.index('giraffes'))\n",
    "print([vocabulary[i] for i in np.argsort(-third_word_dist)[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['leaves', 'from', 'grass', 'hay', 'out', 'some', 'in', 'food', 'off', 'a']\n"
     ]
    }
   ],
   "source": [
    "# Fourth step: feed 'eating' and rank the distribution this call just returned.\n",
    "fourth_word_dist = predict_single_word(net, vocabulary.index('eating'))\n",
    "print([vocabulary[i] for i in np.argsort(-fourth_word_dist)[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def softmax(softmax_inputs, temp):\n",
    "    \"\"\"Numerically stable softmax of `temp * softmax_inputs`.\n",
    "\n",
    "    `temp` multiplies the scores, so it acts as an *inverse* temperature:\n",
    "    larger values sharpen the distribution and 0 makes it uniform.\n",
    "\n",
    "    Args:\n",
    "        softmax_inputs: array of unnormalized scores (a 1-D logit vector here).\n",
    "        temp: scalar multiplier applied to the scores before exponentiation.\n",
    "\n",
    "    Returns:\n",
    "        float64 probabilities summing to 1, or an all-NaN array when the scaled\n",
    "        scores produce NaN (e.g. NaN inputs, or an infinite largest score).\n",
    "    \"\"\"\n",
    "    scaled_inputs = temp * np.asarray(softmax_inputs, dtype=np.float64)\n",
    "    # Shift *after* scaling so the largest exponent is exactly 0: exp() cannot\n",
    "    # overflow for either sign of `temp`, and the sum is always >= 1.\n",
    "    scaled_inputs = scaled_inputs - scaled_inputs.max()\n",
    "    exp_outputs = np.exp(scaled_inputs)\n",
    "    exp_outputs_sum = exp_outputs.sum()\n",
    "    if np.isnan(exp_outputs_sum):\n",
    "        return exp_outputs * float('nan')\n",
    "    return exp_outputs / exp_outputs_sum\n",
    "\n",
    "def random_choice_from_probs(softmax_inputs, temp=1):\n",
    "    \"\"\"Sample an index from softmax(temp * softmax_inputs).\n",
    "\n",
    "    Args:\n",
    "        softmax_inputs: unnormalized scores, one per vocabulary entry.\n",
    "        temp: inverse temperature (see `softmax`); float('inf') means argmax.\n",
    "\n",
    "    Returns:\n",
    "        The sampled index. Falls back to 1 (the first vocabulary-file entry,\n",
    "        '<unk>') only when the distribution is degenerate (NaN).\n",
    "    \"\"\"\n",
    "    # An infinite inverse temperature collapses all mass onto the best score.\n",
    "    if temp == float('inf'):\n",
    "        return np.argmax(softmax_inputs)\n",
    "    probs = softmax(softmax_inputs, temp)\n",
    "    cum_probs = np.cumsum(probs)\n",
    "    total = cum_probs[-1]\n",
    "    if not np.isfinite(total) or total <= 0:\n",
    "        return 1  # degenerate distribution: fall back to <unk>\n",
    "    # Draw r in [0, total) instead of [0, 1): float rounding can leave total\n",
    "    # slightly below 1, which would otherwise let r overshoot every bucket.\n",
    "    r = random.random() * total\n",
    "    # side='right' returns the first index whose cumulative mass exceeds r,\n",
    "    # so entries with zero probability are never selected.\n",
    "    index = int(np.searchsorted(cum_probs, r, side='right'))\n",
    "    if index == len(cum_probs):\n",
    "        # r rounded up to exactly `total`; use the last entry with nonzero mass.\n",
    "        index = int(np.flatnonzero(probs)[-1])\n",
    "    return index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_sentence(net, temp=float('inf'), output='predict', max_words=50):\n",
    "    \"\"\"Decode a sentence one word at a time from the language model.\n",
    "\n",
    "    The first step feeds index 0 (<EOS>) with cont=0; every later step feeds\n",
    "    the previously chosen index back in with cont=1.\n",
    "\n",
    "    Args:\n",
    "        net: caffe.Net taking `cont_sentence` and `input_sentence` inputs.\n",
    "        temp: passed to random_choice_from_probs; the default float('inf')\n",
    "            decodes greedily (argmax at every step).\n",
    "        output: blob whose values are handed to random_choice_from_probs\n",
    "            (which applies a softmax, so presumably pre-softmax scores).\n",
    "        max_words: upper bound on the number of indices produced.\n",
    "\n",
    "    Returns:\n",
    "        List of vocabulary indices; it ends with 0 (<EOS>) unless max_words\n",
    "        was reached first.\n",
    "    \"\"\"\n",
    "    cont_input = np.array([0])\n",
    "    word_input = np.array([0])\n",
    "    sentence = []\n",
    "    while len(sentence) < max_words:\n",
    "        net.forward(cont_sentence=cont_input, input_sentence=word_input)\n",
    "        step_scores = net.blobs[output].data[0, 0, :]\n",
    "        next_word = random_choice_from_probs(step_scores, temp=temp)\n",
    "        sentence.append(next_word)\n",
    "        if next_word == 0:\n",
    "            break\n",
    "        cont_input[0] = 1\n",
    "        word_input[0] = next_word\n",
    "    return sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]\n",
      "['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Greedy decoding: with the default temp=inf every step takes the argmax word.\n",
    "sentence = generate_sentence(net)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]\n",
      "['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Greedy decoding again: argmax at every step, so the result is deterministic.\n",
    "sentence = generate_sentence(net)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 22, 9, 294, 7, 2, 178, 113, 11, 87, 905, 0]\n",
      "['a', 'woman', 'is', 'posing', 'with', 'a', 'cell', 'phone', 'to', 'her', 'ear', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=1.0 (temp scales the scores, so a lower value means more randomness).\n",
    "sentence = generate_sentence(net, temp=1.0)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 28, 26, 2, 38, 209, 3, 2, 38, 152, 0]\n",
      "['a', 'person', 'holding', 'a', 'tennis', 'racket', 'on', 'a', 'tennis', 'court', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Another temp=1.0 sample; results differ between runs.\n",
    "sentence = generate_sentence(net, temp=1.0)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 10, 26, 2, 38, 363, 3, 2, 38, 152, 0]\n",
      "['a', 'man', 'holding', 'a', 'tennis', 'racquet', 'on', 'a', 'tennis', 'court', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=1.5: sharper than 1.0, so closer to the greedy output.\n",
    "sentence = generate_sentence(net, temp=1.5)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 33, 4, 18, 12, 106, 2, 23, 7, 60, 0]\n",
      "['a', 'group', 'of', 'people', 'sitting', 'around', 'a', 'table', 'with', 'food', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Another temp=1.5 sample.\n",
    "sentence = generate_sentence(net, temp=1.5)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 10, 6, 2, 261, 8, 217, 16, 6, 2, 43, 0]\n",
      "['a', 'man', 'in', 'a', 'suit', 'and', 'tie', 'standing', 'in', 'a', 'room', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=3.0: the distribution is sharpened further toward the top words.\n",
    "sentence = generate_sentence(net, temp=3.0)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 10, 26, 2, 38, 363, 3, 2, 38, 152, 0]\n",
      "['a', 'man', 'holding', 'a', 'tennis', 'racquet', 'on', 'a', 'tennis', 'court', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Another temp=3.0 sample.\n",
    "sentence = generate_sentence(net, temp=3.0)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]\n",
      "['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=10.0: sharp enough to match the greedy sentence here.\n",
    "sentence = generate_sentence(net, temp=10.0)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1993, 1074, 86, 6, 40, 4, 2, 126, 0]\n",
      "['staircase', 'laid', 'out', 'in', 'front', 'of', 'a', 'window', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Back to temp=1.0 (plain softmax sampling).\n",
    "sentence = generate_sentence(net, temp=1.0)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 28, 3, 2, 113, 46, 2, 129, 0]\n",
      "['a', 'person', 'on', 'a', 'phone', 'riding', 'a', 'car', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=0.8: flatter than plain softmax, so less likely words appear more often.\n",
    "sentence = generate_sentence(net, temp=0.8)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 16, 60, 6, 136, 192, 7, 641, 16, 20, 11, 27, 0]\n",
      "['a', 'standing', 'food', 'in', 'each', 'hand', 'with', 'cattle', 'standing', 'next', 'to', 'it', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Another temp=0.8 sample.\n",
    "sentence = generate_sentence(net, temp=0.8)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[28, 236, 1042, 7, 69, 1257, 487, 1769, 0]\n",
      "['person', 'taking', 'noodles', 'with', 'other', 'homemade', 'birthday', 'cereal', '<EOS>']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=0.6: an even flatter distribution.\n",
    "sentence = generate_sentence(net, temp=0.6)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5623, 1087, 15, 6888, 472, 361, 8634, 8, 7241, 3, 77, 299, 935, 1296, 15, 12, 5165, 2867, 3979, 743, 4991, 4470, 640, 9, 259, 2308, 4386, 2552, 3797, 2448, 15, 3617, 5364, 4267, 4549, 8086, 176, 2529, 6434, 5445, 370, 7959, 5672, 1742, 4041, 4258, 1153, 8, 610, 2044]\n",
      "['chilli', 'frosting', ',', 'medley', 'salad', 'items', 'sideboard', 'and', 'garnishes', 'on', 'three', 'colorful', 'gold', 'desserts', ',', 'sitting', 'knifes', 'need', 'workspace', 'where', 'exchanging', 'hoses', 'left', 'is', 'pink', 'clearing', 'obstacles', 'vandalized', 'idly', 'afternoon', ',', 'halloween', 'rich', 'fixed', 'aid', 'advertise', 'light', 'times', 'delicate', 'dealership', 'like', 'snowsuits', 'florida', 'than', 'ornamental', 'dr', 'curtains', 'and', 'multiple', 'electrical']\n"
     ]
    }
   ],
   "source": [
    "# Sample at temp=0.5: so flat that <EOS> was never drawn and max_words capped the output.\n",
    "sentence = generate_sentence(net, temp=0.5)\n",
    "print(sentence)\n",
    "print([vocabulary[i] for i in sentence])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
